{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Notebook Setup \n",
    "The following cell will install Drake, checkout the underactuated repository, and set up the path (only if necessary).\n",
    "- On Google's Colaboratory, this **will take approximately two minutes** on the first time it runs (to provision the machine), but should only need to reinstall once every 12 hours.  Colab will ask you to \"Reset all runtimes\"; say no to save yourself the reinstall.\n",
    "- On Binder, the machines should already be provisioned by the time you can run this; it should return (almost) instantly.\n",
    "\n",
    "More details are available [here](http://underactuated.mit.edu/drake.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "  import pydrake\n",
    "  import underactuated\n",
    "except ImportError:\n",
    "  !curl -s https://raw.githubusercontent.com/RussTedrake/underactuated/master/scripts/setup/jupyter_setup.py > jupyter_setup.py\n",
    "  from jupyter_setup import setup_underactuated\n",
    "  setup_underactuated()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Value Iteration for the Double Integrator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "from matplotlib import cm\n",
    "from IPython import get_ipython\n",
    "\n",
    "from pydrake.all import (\n",
    "  DiagramBuilder, DynamicProgrammingOptions, FittedValueIteration, \n",
    "  LinearSystem, Simulator\n",
    ")\n",
    "from underactuated.double_integrator import DoubleIntegratorVisualizer\n",
    "from underactuated.jupyter import SetupMatplotlibBackend\n",
    "plt_is_interactive = SetupMatplotlibBackend()\n",
    "\n",
    "\n",
    "def DoubleIntegrator():\n",
    "  return LinearSystem(A=np.mat('0 1; 0 0'),\n",
    "                      B=np.mat('0; 1'),\n",
    "                      C=np.eye(2),\n",
    "                      D=np.zeros((2,1)))\n",
    "\n",
    "plant = DoubleIntegrator()\n",
    "simulator = Simulator(plant)\n",
    "options = DynamicProgrammingOptions()\n",
    "\n",
    "qbins = np.linspace(-3., 3., 31)\n",
    "qdotbins = np.linspace(-3., 3., 51)\n",
    "state_grid = [set(qbins), set(qdotbins)]\n",
    "\n",
    "input_limit = 1.\n",
    "input_grid = [set(np.linspace(-input_limit, input_limit, 9))]\n",
    "timestep = 0.01\n",
    "\n",
    "[Q, Qdot] = np.meshgrid(qbins, qdotbins)\n",
    "\n",
    "def draw(iteration, mesh, cost_to_go, policy):\n",
    "    # Drawing is slow, don't draw every frame.\n",
    "    if iteration % 20 != 0:\n",
    "        return\n",
    "    plt.title(\"iteration \" + str(iteration))\n",
    "    J = np.reshape(cost_to_go, Q.shape)\n",
    "    surf = ax1.plot_surface(Q, Qdot, J, rstride=1, cstride=1, cmap=cm.jet)\n",
    "\n",
    "    Pi = np.reshape(policy, Q.shape)\n",
    "    surf2 = ax2.plot_surface(Q, Qdot, Pi, rstride=1, cstride=1, cmap=cm.jet)\n",
    "\n",
    "    if plt.get_backend() != u\"template\":\n",
    "        fig.canvas.draw()\n",
    "        plt.pause(1e-10)\n",
    "\n",
    "    surf.remove()\n",
    "    surf2.remove()\n",
    "\n",
    "if plt_is_interactive:\n",
    "    options.visualization_callback = draw\n",
    "\n",
    "def solve():\n",
    "    policy, cost_to_go = FittedValueIteration(simulator, cost_function, state_grid,\n",
    "                                              input_grid, timestep, options)\n",
    "\n",
    "    J = np.reshape(cost_to_go, Q.shape)\n",
    "    surf = ax1.plot_surface(Q, Qdot, J, rstride=1, cstride=1, cmap=cm.jet)\n",
    "    Pi = np.reshape(policy.get_output_values(), Q.shape)\n",
    "    surf = ax2.plot_surface(Q, Qdot, Pi, rstride=1, cstride=1, cmap=cm.jet)\n",
    "    return policy\n",
    "    \n",
    "def simulate():\n",
    "    # Animate the resulting policy.\n",
    "    builder = DiagramBuilder()\n",
    "    plant = builder.AddSystem(DoubleIntegrator())\n",
    "\n",
    "    vi_policy = builder.AddSystem(policy)\n",
    "    builder.Connect(plant.get_output_port(0), vi_policy.get_input_port(0))\n",
    "    builder.Connect(vi_policy.get_output_port(0), plant.get_input_port(0))\n",
    "\n",
    "    visualizer = builder.AddSystem(DoubleIntegratorVisualizer())\n",
    "    builder.Connect(plant.get_output_port(0), visualizer.get_input_port(0))\n",
    "\n",
    "    diagram = builder.Build()\n",
    "    simulator = Simulator(diagram)\n",
    "\n",
    "    simulator.get_mutable_context().SetContinuousState([-10.0, 0.0])\n",
    "\n",
    "    simulator.AdvanceTo(10. if get_ipython() is not None else 0.1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def min_time_cost(context):\n",
    "    x = context.get_continuous_state_vector().CopyToVector()\n",
    "    if x.dot(x) < .05:\n",
    "        return 0.\n",
    "    return 1.\n",
    "\n",
    "cost_function = min_time_cost\n",
    "options.convergence_tol = 0.001\n",
    "\n",
    "fig = plt.figure(figsize=(9, 4))\n",
    "ax1, ax2 = fig.subplots(1, 2, subplot_kw=dict(projection='3d'))\n",
    "ax1.set_xlabel(\"q\")\n",
    "ax1.set_ylabel(\"qdot\")\n",
    "ax1.set_title(\"Cost-to-Go\")\n",
    "ax2.set_xlabel(\"q\")\n",
    "ax2.set_ylabel(\"qdot\")\n",
    "ax2.set_title(\"Policy\")\n",
    "\n",
    "policy = solve()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "simulate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def quadratic_regulator_cost(context):\n",
    "    x = context.get_continuous_state_vector().CopyToVector()\n",
    "    u = plant.EvalVectorInput(context, 0).CopyToVector()\n",
    "    return x.dot(x) + u.dot(u)\n",
    "\n",
    "cost_function = quadratic_regulator_cost\n",
    "options.convergence_tol = 0.1\n",
    "\n",
    "fig = plt.figure(figsize=(9, 4))\n",
    "ax1, ax2 = fig.subplots(1, 2, subplot_kw=dict(projection='3d'))\n",
    "ax1.set_xlabel(\"q\")\n",
    "ax1.set_ylabel(\"qdot\")\n",
    "ax1.set_title(\"Cost-to-Go\")\n",
    "ax2.set_xlabel(\"q\")\n",
    "ax2.set_ylabel(\"qdot\")\n",
    "ax2.set_title(\"Policy\")\n",
    "\n",
    "policy = solve()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "simulate()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
